from .hdf5s_dataloader import HDF5Dataset
import numpy as np
import pandas as pd
import torch
from torch.utils.data.dataloader import DataLoader


def get_dataloaders_prostate_mri(settings):
    """Build the train, validation and test data loaders.

    The training CSV is split into train/validation subsets with
    ``dataset_split_train_val`` (a random, per-class split), while the test
    set is read from its own CSV and HDF5 directory.

    Args:
        settings: Configuration object. The attributes read here are
            ``label_csv`` (train/val label CSV with ``study_id`` and
            ``DCE_Final`` columns), ``hdf5_dir``, ``labels_csv_test`` (test
            label CSV with ``UUID`` and ``DCE_Final`` columns),
            ``hdf5_dir_test`` and ``batch_size``.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles its samples.
    """
    # Fraction of each class assigned to the training subset; the remainder
    # becomes the validation subset.
    train_fraction = 0.85

    # Split train and val sets. The split is performed exactly once: a
    # repeated call would only discard the first result and consume extra
    # draws from the global NumPy RNG.
    label_df_train = pd.read_csv(settings.label_csv)
    subject_ids_train = label_df_train["study_id"].values
    subject_labels_train = label_df_train["DCE_Final"].values
    (
        training_samples,
        training_labels,
        validation_samples,
        validation_labels,
    ) = dataset_split_train_val(
        subject_ids_train, subject_labels_train, train_fraction, 1
    )

    # NOTE(review): the trailing ``None`` is the fifth positional argument of
    # HDF5Dataset; its meaning is defined in hdf5s_dataloader — confirm there.
    train_dataset = HDF5Dataset(
        settings.hdf5_dir, training_samples, training_labels, "train", None
    )
    train_loader = DataLoader(
        train_dataset, batch_size=settings.batch_size, shuffle=True
    )
    val_dataset = HDF5Dataset(
        settings.hdf5_dir, validation_samples, validation_labels, "val", None
    )
    val_loader = DataLoader(
        val_dataset, batch_size=settings.batch_size, shuffle=False
    )

    # Load test set (identified by "UUID" rather than "study_id").
    label_df_test = pd.read_csv(settings.labels_csv_test)
    test_samples = label_df_test["UUID"].values
    test_labels = label_df_test["DCE_Final"].values
    test_dataset = HDF5Dataset(
        settings.hdf5_dir_test, test_samples, test_labels, "test", None
    )
    test_loader = DataLoader(
        test_dataset, batch_size=settings.batch_size, shuffle=False
    )
    return train_loader, val_loader, test_loader


def dataset_split_train_val(
    subject_ids, subject_label, perc_training_set, fixed_val_set
):
    """Randomly split subjects into train and validation sets, per class.

    Positive and negative subjects are permuted independently and the first
    ``int(n * perc_training_set)`` of each class go to the training set; the
    rest go to the validation set. Both resulting sets are then shuffled so
    that the classes are interleaved.

    Randomness comes from NumPy's global RNG, so the split differs between
    runs unless the caller seeds ``np.random`` beforehand.

    Args:
        subject_ids: 1-D array-like of subject identifiers.
        subject_label: 1-D array-like of labels, aligned with
            ``subject_ids``. Values are interpreted through a boolean cast
            (non-zero means positive).
        perc_training_set: Fraction of each class assigned to training.
        fixed_val_set: Currently ignored. It is kept so existing callers
            keep working; no seed is set inside this function.

    Returns:
        Tuple ``(training_samples, training_labels, validation_samples,
        validation_labels)``. Labels are float arrays of ``1.0`` / ``0.0``.
    """
    subject_ids = np.asarray(subject_ids)
    # Builtin ``bool``/``int`` replace the ``np.bool``/``np.int`` aliases,
    # which were removed in NumPy 1.24.
    is_positive = np.asarray(subject_label).astype(bool)

    positive_subjects = subject_ids[is_positive]
    negative_subjects = subject_ids[~is_positive]

    # RNG draw order (positives, negatives, train shuffle, val shuffle) is
    # kept stable so externally seeded runs stay reproducible.
    positive_order = np.random.permutation(positive_subjects.size)
    negative_order = np.random.permutation(negative_subjects.size)

    # int() truncates toward zero, so a fractional count rounds down and the
    # leftover subject lands in the validation set.
    positive_cut = int(positive_subjects.size * perc_training_set)
    negative_cut = int(negative_subjects.size * perc_training_set)

    training_samples, training_labels = _merge_and_shuffle(
        positive_subjects[positive_order[:positive_cut]],
        negative_subjects[negative_order[:negative_cut]],
    )
    validation_samples, validation_labels = _merge_and_shuffle(
        positive_subjects[positive_order[positive_cut:]],
        negative_subjects[negative_order[negative_cut:]],
    )
    return training_samples, training_labels, validation_samples, validation_labels


def _merge_and_shuffle(positives, negatives):
    """Concatenate both classes, build 1.0/0.0 labels, and shuffle them jointly."""
    samples = np.concatenate((positives, negatives))
    labels = np.concatenate((np.ones(positives.size), np.zeros(negatives.size)))
    order = np.random.permutation(samples.size)
    return samples[order], labels[order]
